#include <c10/util/Exception.h>
#include <c10/xpu/XPUFunctions.h>

#include <vector>

namespace c10::xpu {
namespace {

/*
 * Note [Device Management]
 *
 * An Intel GPU device qualifies as a type of SYCL device. This classification
 * allows for the runtime querying of Intel GPU device information through the
 * SYCL runtime library.
 *
 * Device status is managed through a SYCL device pool, with SYCL devices
 * determined at runtime. There's currently a SYCL device pool that is lazily
 * created and only initialized once, ensuring thread-local safety. Each device
 * within the device pool shares the same default context.
 *
 * In certain scenarios, GPU devices may reside on separate SYCL platforms. For
 * instance, on Windows, an integrated GPU (iGPU) and a discrete GPU (dGPU) may
 * exist on different platforms. Since sycl::context cannot span across multiple
 * platforms, creating a single default context that includes both becomes
 * infeasible.
 *
 * To address this limitation, we prioritize the enumeration of dGPU. The device
 * enumeration logic is as follows:
 * 1. Identify the first Level Zero (L0) platform that contains at least one
 *    dGPU and enumerate all dGPUs on that platform.
 * 2. If no dGPU is found, identify the first L0 platform containing at least
 *    one iGPU and enumerate all iGPUs on that platform.
 * 3. If neither dGPUs nor iGPUs are found, conclude that no GPUs are available.
 */
thread_local DeviceIndex curDeviceIndex = 0;

struct DevicePool {
  std::vector<std::unique_ptr<sycl::device>> devices;
  std::unique_ptr<sycl::context> context;
} gDevicePool;

void enumDevices(std::vector<std::unique_ptr<sycl::device>>& devices) {
  // See Note [Device Management] for more details.
  auto platform_list = sycl::platform::get_platforms();
  auto is_igpu = [](const sycl::device& device) {
    // Generally, iGPUs share a unified memory subsystem with the host.
    return device.get_info<sycl::info::device::host_unified_memory>();
  };

  // Check if a platform contains at least one GPU (either iGPU or dGPU).
  auto has_gpu = [&is_igpu](const sycl::platform& platform, bool check_igpu) {
    // Only consider platforms using the Level Zero backend.
    if (platform.get_backend() != sycl::backend::ext_oneapi_level_zero) {
      return false;
    }
    // Check if the platform contains at least one GPU.
    for (const auto& device : platform.get_devices()) {
      if (device.is_gpu() &&
          (check_igpu ? is_igpu(device) : !is_igpu(device))) {
        return true;
      }
    }
    // No GPU found on the platform.
    return false;
  };

  // Case 1: Platform with dGPU found. Most platforms with dGPU only have dGPU
  // or a combination of dGPU and iGPU.
  for (const auto& platform : platform_list) {
    // Find the first platform that contains at least one dGPU.
    if (has_gpu(platform, /*check_igpu=*/false)) {
      for (const auto& device : platform.get_devices()) {
        // Only add all dGPUs to the device list.
        if (device.is_gpu() && !is_igpu(device)) {
          devices.push_back(std::make_unique<sycl::device>(device));
        }
      }
      return; // Exit early since we already found a platform with dGPU.
    }
  }

  // Case 2: No dGPU found, but a platform with iGPU is available.
  for (const auto& platform : platform_list) {
    // Find the first platform that contains at least one iGPU.
    if (has_gpu(platform, /*check_igpu=*/true)) {
      for (const auto& device : platform.get_devices()) {
        // Add all iGPUs to the device list.
        if (device.is_gpu()) { // If the device is a GPU, it must be a iGPU.
          devices.push_back(std::make_unique<sycl::device>(device));
        }
      }
      return; // Exit early since we already found a platform with iGPU.
    }
  }

  // Case 3: No GPUs found (neither dGPU nor iGPU) - Do nothing.
}

inline void initGlobalDevicePoolState() {
  // Attempt to initialize XPU devices. If no device is found or the driver is
  // not installed correctly, issue a warning message instead of raising an
  // exception to avoid disrupting the user experience.
  try {
    // Enumerate all GPU devices and record them.
    enumDevices(gDevicePool.devices);
  } catch (const sycl::exception& e) {
    TORCH_WARN(
        "Failed to initialize XPU devices. The driver may not be installed, installed incorrectly, or incompatible with the current setup. ",
        "Please refer to the guideline (https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support) for proper installation and configuration.");
    return;
  }
  if (gDevicePool.devices.empty()) {
    TORCH_WARN("XPU device count is zero!");
    return;
  }
  // Ensures that the number of GPU devices does not exceed the maximum
  // allowable value for DeviceIndex.
  TORCH_CHECK(
      gDevicePool.devices.size() <= std::numeric_limits<DeviceIndex>::max(),
      "Too many XPU devices, DeviceIndex overflowed!");
  // Check each device's architecture and issue a warning if it is older than
  // the officially supported range (Intel GPUs starting from Arc (Alchemist)
  // series).
  namespace syclex = sycl::ext::oneapi::experimental;
  for (const auto& device : gDevicePool.devices) {
    auto architecture = device->get_info<syclex::info::device::architecture>();
    if (architecture < syclex::architecture::intel_gpu_acm_g10) {
      TORCH_WARN(
          "The detected GPU (",
          device->get_info<sycl::info::device::name>(),
          ") is not officially supported by PyTorch XPU. Running workloads on this device may result in unexpected behavior.\n",
          "For stable and fully supported execution, please use GPUs based on Intel Arc (Alchemist) series or newer.\n",
          "Refer to the hardware prerequisites for more information: ",
          "https://github.com/pytorch/pytorch/blob/main/docs/source/notes/get_start_xpu.rst#hardware-prerequisite");
    }
  }

  // The default context is utilized for each Intel GPU device, allowing the
  // retrieval of the context from any GPU device.
  const auto& platform = gDevicePool.devices[0]->get_platform();
  gDevicePool.context = std::make_unique<sycl::context>(
#if SYCL_COMPILER_VERSION >= 20250200
      platform.khr_get_default_context());
#else
      platform.ext_oneapi_get_default_context());
#endif
}

inline void initDevicePoolCallOnce() {
  auto static init_flag [[maybe_unused]] = [] {
    initGlobalDevicePoolState();
    return true;
  }();
}

void initDeviceProperties(DeviceProp* device_prop, DeviceIndex device) {
  using namespace sycl::info;
  using namespace sycl::ext;
  // Get raw sycl device associated with device index.
  auto& raw_device = *gDevicePool.devices[device];

  // Initialize the device properties associated with the specific device.
#define ASSIGN_DEVICE_PROP(property) \
  device_prop->property = raw_device.get_info<device::property>();

#define ASSIGN_EXT_DEVICE_PROP(property, aspect_tag, default_value)            \
  device_prop->property = raw_device.has(sycl::aspect::ext_intel_##aspect_tag) \
      ? raw_device.get_info<intel::info::device::property>()                   \
      : default_value;

#define ASSIGN_DEVICE_ASPECT(member) \
  device_prop->has_##member = raw_device.has(sycl::aspect::member);

#define ASSIGN_EXP_CL_ASPECT(member) \
  device_prop->has_##member =        \
      raw_device.ext_oneapi_supports_cl_extension("cl_intel_" #member);

#define ASSIGN_EXP_DEVICE_PROP(property) \
  device_prop->property =                \
      raw_device.get_info<oneapi::experimental::info::device::property>();

  AT_FORALL_XPU_DEVICE_PROPERTIES(ASSIGN_DEVICE_PROP);

  device_prop->platform_name =
      raw_device.get_info<device::platform>().get_info<platform::name>();

  AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(ASSIGN_EXT_DEVICE_PROP);

  AT_FORALL_XPU_DEVICE_ASPECT(ASSIGN_DEVICE_ASPECT);

  AT_FORALL_XPU_EXP_CL_ASPECT(ASSIGN_EXP_CL_ASPECT);

#if SYCL_COMPILER_VERSION >= 20250000
  AT_FORALL_XPU_EXP_DEVICE_PROPERTIES(ASSIGN_EXP_DEVICE_PROP);
#endif

  return;
}

} // anonymous namespace

sycl::device& get_raw_device(DeviceIndex device) {
  initDevicePoolCallOnce();
  check_device_index(device);
  return *gDevicePool.devices[device];
}

sycl::context& get_device_context() {
  initDevicePoolCallOnce();
  TORCH_CHECK(
      gDevicePool.context,
      "Device pool initialization failed, you might not have an XPU device.")
  return *gDevicePool.context;
}

void get_device_properties(DeviceProp* device_prop, DeviceIndex device) {
  initDevicePoolCallOnce();
  TORCH_CHECK(device_prop, "device_prop is an invalid pointer.");
  check_device_index(device);
  initDeviceProperties(device_prop, device);
}

DeviceIndex get_device_idx_from_pointer(void* ptr) {
  initDevicePoolCallOnce();
  TORCH_CHECK(ptr, "ptr is an invalid pointer.");
  auto type = sycl::get_pointer_type(ptr, get_device_context());
  TORCH_CHECK(
      type == sycl::usm::alloc::device, "ptr is not a device type pointer.");

  sycl::device raw_device = sycl::get_pointer_device(ptr, get_device_context());
  auto match_device = [raw_device](const auto& device) -> bool {
    return raw_device == *device;
  };
  auto it = std::find_if(
      gDevicePool.devices.begin(), gDevicePool.devices.end(), match_device);
  TORCH_CHECK(
      it != gDevicePool.devices.end(),
      "Can't find the pointer from XPU devices.");
  return static_cast<DeviceIndex>(
      std::distance(gDevicePool.devices.begin(), it));
}

DeviceIndex device_count() {
  initDevicePoolCallOnce();
  return static_cast<DeviceIndex>(gDevicePool.devices.size());
}

DeviceIndex device_count_ensure_non_zero() {
  auto count = device_count();
  // Zero gpus could produce a warning in `device_count` but we fail here.
  TORCH_CHECK(count, "No XPU devices are available.");
  return count;
}

DeviceIndex current_device() {
  initDevicePoolCallOnce();
  return curDeviceIndex;
}

void set_device(DeviceIndex device) {
  initDevicePoolCallOnce();
  check_device_index(device);
  curDeviceIndex = device;
}

c10::DeviceIndex exchange_device(c10::DeviceIndex to_device) {
  auto cur_device = current_device();
  if (to_device == cur_device) {
    return cur_device;
  }
  set_device(to_device);
  return cur_device;
}

c10::DeviceIndex maybe_exchange_device(c10::DeviceIndex to_device) {
  return exchange_device(to_device);
}

} // namespace c10::xpu
